import torch
from misc.utils import *

class DataLoader:
    """Build ``torch_geometric`` loaders for the active client's data.

    Loaders over a client's partition are rebuilt only when the active
    client changes (see ``switch``). When ``args.eval_global`` is set,
    loaders over the shared test/val splits are also provided; those splits
    are loaded once and reused across clients.

    Attributes (set when the corresponding data is loaded):
        partition, pa_loader: the active client's partition and its loader.
        test, te_loader: the test split and its loader (``eval_global`` only).
        valid, va_loader: the val split and its loader (``eval_global`` only).
    """

    def __init__(self, args, is_server=False):
        """
        Args:
            args: run configuration. This class reads ``args.eval_global``
                and forwards ``args`` to ``get_data``.
            is_server: when true and ``args.eval_global`` is set, the
                test/val loaders are built immediately.
        """
        self.args = args
        self.n_workers = 1
        self.client_id = None

        # Local import; presumably to defer the torch_geometric dependency
        # until a loader is actually needed.
        from torch_geometric.loader import DataLoader
        self.DataLoader = DataLoader

        if is_server and self.args.eval_global:
            self._load_global_eval()

    def _make_loader(self, dataset):
        """Wrap *dataset* in a non-shuffling, batch-size-1 loader."""
        return self.DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                               num_workers=self.n_workers, pin_memory=False)

    def _load_global_eval(self):
        """Load the test/val splits and build their loaders (only once)."""
        # get_data does not use client_id for 'test'/'val' (it always reads
        # '{mode}.pt'), so these splits are the same for every client.
        # Reloading them on each switch() would only repeat disk I/O.
        if getattr(self, 'te_loader', None) is not None:
            return
        self.test = get_data(self.args, mode='test')
        self.te_loader = self._make_loader(self.test)
        self.valid = get_data(self.args, mode='val')
        self.va_loader = self._make_loader(self.valid)

    def switch(self, client_id):
        """Make *client_id* the active client, loading its partition.

        Does nothing if *client_id* is already active. When
        ``args.eval_global`` is set, ensures the test/val loaders exist.
        """
        if self.client_id == client_id:
            return
        self.client_id = client_id
        self.partition = get_data(self.args, mode='partition', client_id=client_id)
        self.pa_loader = self._make_loader(self.partition)

        if self.args.eval_global:
            self._load_global_eval()

def get_data(args, mode, client_id=-1):
    """Load one stored dataset split and return it as a one-element list.

    Args:
        args: run configuration. Reads ``data_path``, ``dataset`` and
            ``n_clients``, plus ``dist`` for client splits.
        mode: ``'test'`` or ``'val'`` selects a split shared by all clients.
            Any other value selects the per-client file
            ``{dist}_{mode}_{client_id}.pt``.
        client_id: used only for the per-client file. It is inserted
            verbatim, so the default ``-1`` yields a ``..._-1.pt`` name.

    Returns:
        ``[data]``, where ``data`` is the ``'data'`` entry (test/val) or the
        ``'client_data'`` entry (client split) of the loaded object.
    """
    prefix = f'{args.dataset}/{args.n_clients}'
    if mode not in ['test', 'val']:
        rel_path = f'{prefix}/{args.dist}_{mode}_{client_id}.pt'
        entry = 'client_data'
    else:
        rel_path = f'{prefix}/{mode}.pt'
        entry = 'data'

    # torch_load comes from misc.utils; presumably it resolves rel_path
    # against args.data_path. Confirm there.
    loaded = torch_load(args.data_path, rel_path)
    return [loaded[entry]]
